{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<img src=\"http://developer.download.nvidia.com/notebooks/dlsw-notebooks/tensorrt_torchtrt_efficientnet/nvidia_logo.png\" width=\"90px\">\n",
    "\n",
    "# PySpark LLM Inference: Qwen-2.5 Text Summarization\n",
    "\n",
    "In this notebook, we demonstrate distributed batch inference with [Qwen-2.5](https://huggingface.co/Qwen/Qwen2.5-7B-Instruct), using the open weights hosted on Hugging Face.\n",
    "\n",
    "Qwen2.5-7B-Instruct is an instruction-tuned version of the Qwen2.5-7B base model. We'll show how to use it to perform text summarization.\n",
    "\n",
    "**Note:** Running this model on GPU with 16-bit precision requires **~16GB** of GPU RAM. Make sure your instances have sufficient GPU capacity."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The dataset we'll use is compressed with Zstandard, so we install the `zstandard` package to read it."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Collecting zstandard\n",
      "  Downloading zstandard-0.23.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (3.0 kB)\n",
      "Downloading zstandard-0.23.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (5.4 MB)\n",
      "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m5.4/5.4 MB\u001b[0m \u001b[31m66.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
      "\u001b[?25hInstalling collected packages: zstandard\n",
      "Successfully installed zstandard-0.23.0\n",
      "Note: you may need to restart the kernel to use updated packages.\n"
     ]
    }
   ],
   "source": [
    "%pip install zstandard"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Explicitly enable Hugging Face tokenizer parallelism so that it is not disabled when PySpark forks worker processes.\n",
    "# See https://github.com/huggingface/transformers/issues/5486 for more info.\n",
    "import os\n",
    "os.environ[\"TOKENIZERS_PARALLELISM\"] = \"true\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Check the cluster environment to handle any platform-specific configurations."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "on_databricks = os.environ.get(\"DATABRICKS_RUNTIME_VERSION\", False)\n",
    "on_dataproc = os.environ.get(\"DATAPROC_IMAGE_VERSION\", False)\n",
    "on_standalone = not (on_databricks or on_dataproc)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# For cloud environments, store the model on the distributed file system.\n",
    "if on_databricks:\n",
    "    models_dir = \"/dbfs/FileStore/spark-dl-models\"\n",
    "    dbutils.fs.mkdirs(\"/FileStore/spark-dl-models\")\n",
    "    model_path = f\"{models_dir}/qwen-2.5-7b\"\n",
    "elif on_dataproc:\n",
    "    models_dir = \"/mnt/gcs/spark-dl-models\"\n",
    "    os.makedirs(models_dir, exist_ok=True)\n",
    "    model_path = f\"{models_dir}/qwen-2.5-7b\"\n",
    "else:\n",
    "    model_path = os.path.abspath(\"qwen-2.5-7b\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Download the model from the Hugging Face Hub."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from huggingface_hub import snapshot_download\n",
    "\n",
    "model_path = snapshot_download(\n",
    "    repo_id=\"Qwen/Qwen2.5-7B-Instruct\",\n",
    "    local_dir=model_path\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Warmup: Running locally\n",
    "\n",
    "**Note**: If the driver node does not have sufficient GPU capacity, proceed to the PySpark section."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "352b738e1a2442b0a997467aaf6eb0ad",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "import torch\n",
    "from transformers import AutoModelForCausalLM, AutoTokenizer\n",
    "\n",
    "model = AutoModelForCausalLM.from_pretrained(\n",
    "    model_path,\n",
    "    torch_dtype=torch.bfloat16,\n",
    "    device_map=\"auto\"\n",
    ")\n",
    "tokenizer = AutoTokenizer.from_pretrained(model_path, padding_side=\"left\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "system_prompt = {\n",
    "    \"role\": \"system\",\n",
    "    \"content\": \"You are a knowledgeable AI assistant that provides accurate answers to questions.\"\n",
    "}\n",
    "\n",
    "queries = [\n",
    "    \"How many vowels are in 'elephant'?\",\n",
    "    \"What is the square root of 16?\",\n",
    "    \"How many planets are in our solar system?\"\n",
    "]\n",
    "\n",
    "prompts = [\n",
    "    [\n",
    "        system_prompt,\n",
    "        {\"role\": \"user\", \"content\": query}\n",
    "    ] for query in queries\n",
    "]\n",
    "\n",
    "text = tokenizer.apply_chat_template(\n",
    "    prompts,\n",
    "    tokenize=False,\n",
    "    add_generation_prompt=True,\n",
    ")\n",
    "\n",
    "model_inputs = tokenizer(text, return_tensors=\"pt\", padding=True).to(model.device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [],
   "source": [
    "generated_ids = model.generate(\n",
    "    **model_inputs,\n",
    "    max_new_tokens=256,\n",
    ")\n",
    "\n",
    "# Strip the prompt tokens so that only the newly generated text is decoded.\n",
    "outputs = tokenizer.batch_decode(generated_ids[:, model_inputs.input_ids.shape[1]:], skip_special_tokens=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Q: How many vowels are in 'elephant'?\n",
      "A: The word \"elephant\" contains 3 vowels. The vowels are 'e', 'e', and 'a'.\n",
      "\n",
      "Q: What is the square root of 16?\n",
      "A: The square root of 16 is 4, because \\(4 \\times 4 = 16\\).\n",
      "\n",
      "Q: How many planets are in our solar system?\n",
      "A: There are eight planets in our solar system. They are, in order from the Sun:\n",
      "\n",
      "1. Mercury\n",
      "2. Venus\n",
      "3. Earth\n",
      "4. Mars\n",
      "5. Jupiter\n",
      "6. Saturn\n",
      "7. Uranus\n",
      "8. Neptune\n",
      "\n",
      "Pluto was previously considered the ninth planet but is now classified as a dwarf planet.\n",
      "\n"
     ]
    }
   ],
   "source": [
    "for query, output in zip(queries, outputs):\n",
    "    print(f\"Q: {query}\\nA: {output}\\n\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [],
   "source": [
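    "# Free the GPU memory held by the local model and its tensors before starting the Triton servers.\n",
    "import gc\n",
    "\n",
    "del model, model_inputs, generated_ids\n",
    "gc.collect()\n",
    "torch.cuda.empty_cache()"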
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## PySpark"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "from pyspark.sql.types import StringType\n",
    "from pyspark import SparkConf\n",
    "from pyspark.sql import SparkSession\n",
    "from pyspark.sql.functions import pandas_udf, col, struct, length, lit, concat\n",
    "from pyspark.ml.functions import predict_batch_udf"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import datasets\n",
    "from datasets import load_dataset\n",
    "datasets.disable_progress_bars()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Create Spark Session\n",
    "\n",
    "For local standalone clusters, we'll connect to the cluster and create the Spark session.  \n",
    "For cloud service provider (CSP) environments, Spark will either be preconfigured (Databricks) or we'll need to create the Spark session ourselves (Dataproc)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "25/02/16 11:48:57 WARN Utils: Your hostname, cb4ae00-lcedt resolves to a loopback address: 127.0.1.1; using 10.110.47.100 instead (on interface eno1)\n",
      "25/02/16 11:48:57 WARN Utils: Set SPARK_LOCAL_IP if you need to bind to another address\n",
      "Setting default log level to \"WARN\".\n",
      "To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).\n",
      "25/02/16 11:48:57 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable\n"
     ]
    }
   ],
   "source": [
    "conf = SparkConf()\n",
    "\n",
    "if 'spark' not in globals():\n",
    "    if on_standalone:\n",
    "        import socket\n",
    "        conda_env = os.environ.get(\"CONDA_PREFIX\")\n",
    "        hostname = socket.gethostname()\n",
    "        conf.setMaster(f\"spark://{hostname}:7077\")\n",
    "        conf.set(\"spark.pyspark.python\", f\"{conda_env}/bin/python\")\n",
    "        conf.set(\"spark.pyspark.driver.python\", f\"{conda_env}/bin/python\")\n",
    "\n",
    "    # One GPU per executor, shared by up to 8 concurrent tasks (0.125 GPU each).\n",
    "    conf.set(\"spark.executor.cores\", \"8\")\n",
    "    conf.set(\"spark.task.maxFailures\", \"1\")\n",
    "    conf.set(\"spark.task.resource.gpu.amount\", \"0.125\")\n",
    "    conf.set(\"spark.executor.resource.gpu.amount\", \"1\")\n",
    "    # Use Arrow for pandas <-> Spark conversion and reuse Python workers across tasks.\n",
    "    conf.set(\"spark.sql.execution.arrow.pyspark.enabled\", \"true\")\n",
    "    conf.set(\"spark.python.worker.reuse\", \"true\")\n",
    "\n",
    "spark = SparkSession.builder.appName(\"spark-dl-examples\").config(conf=conf).getOrCreate()\n",
    "sc = spark.sparkContext"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Load and Preprocess DataFrame\n",
    "\n",
    "Load the first 500 samples of the [PUBMED abstracts dataset](https://huggingface.co/datasets/casinca/PUBMED_title_abstracts_2019_baseline) from Hugging Face and store them in a Spark DataFrame."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "pubmed_dataset = load_dataset(\"casinca/PUBMED_title_abstracts_2019_baseline\", split=\"train\", streaming=True)\n",
    "pubmed_pds = pd.Series([sample[\"text\"] for sample in pubmed_dataset.take(500)])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "df = spark.createDataFrame(pubmed_pds, schema=StringType())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "+----------------------------------------------------------------------------------------------------+\n",
      "|                                                                                               value|\n",
      "+----------------------------------------------------------------------------------------------------+\n",
      "|Epidemiology of hypoxaemia in children with acute lower respiratory infection.\\nTo determine the ...|\n",
      "|Clinical signs of hypoxaemia in children with acute lower respiratory infection: indicators of ox...|\n",
      "|Hypoxaemia in children with severe pneumonia in Papua New Guinea.\\nTo investigate the severity an...|\n",
      "|Oxygen concentrators and cylinders.\\nA comparison is made between oxygen cylinders and oxygen con...|\n",
      "|Oxygen supply in rural africa: a personal experience.\\nOxygen is one of the essential medical sup...|\n",
      "+----------------------------------------------------------------------------------------------------+\n",
      "only showing top 5 rows\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "                                                                                \r"
     ]
    }
   ],
   "source": [
    "df.show(5, truncate=100)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Format each sample into the chat template, including a system prompt to guide generation."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "system_prompt = '''You are a knowledgeable AI assistant. Your job is to create a 2-3 sentence summary \n",
    "of a research abstract that captures the main objective, methodology, and key findings, using clear \n",
    "language while preserving technical accuracy and quantitative results.'''\n",
    "\n",
    "df = df.select(\n",
    "    concat(\n",
    "        lit(\"<|im_start|>system\\n\"),\n",
    "        lit(system_prompt),\n",
    "        lit(\"<|im_end|>\\n<|im_start|>user\\n\"),\n",
    "        col(\"value\"),\n",
    "        lit(\"<|im_end|>\\n<|im_start|>assistant\\n\")\n",
    "    ).alias(\"prompt\")\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<|im_start|>system\n",
      "You are a knowledgeable AI assistant. Your job is to create a 2-3 sentence summary \n",
      "of a research abstract that captures the main objective, methodology, and key findings, using clear \n",
      "language while preserving technical accuracy and quantitative results.<|im_end|>\n",
      "<|im_start|>user\n",
      "Epidemiology of hypoxaemia in children with acute lower respiratory infection.\n",
      "To determine the prevalence of hypoxaemia in children aged under 5 years suffering acute lower respiratory infections (ALRI), the risk factors for hypoxaemia in children under 5 years of age with ALRI, and the association of hypoxaemia with an increased risk of dying in children of the same age. Systematic review of the published literature. Out-patient clinics, emergency departments and hospitalisation wards in 23 health centres from 10 countries. Cohort studies reporting the frequency of hypoxaemia in children under 5 years of age with ALRI, and the association between hypoxaemia and the risk of dying. Prevalence of hypoxaemia measured in children with ARI and relative risks for the association between the severity of illness and the frequency of hypoxaemia, and between hypoxaemia and the risk of dying. Seventeen published studies were found that included 4,021 children under 5 with acute respiratory infections (ARI) and reported the prevalence of hypoxaemia. Out-patient children and those with a clinical diagnosis of upper ARI had a low risk of hypoxaemia (pooled estimate of 6% to 9%). The prevalence increased to 31% and to 43% in patients in emergency departments and in cases with clinical pneumonia, respectively, and it was even higher among hospitalised children (47%) and in those with radiographically confirmed pneumonia (72%). The cumulated data also suggest that hypoxaemia is more frequent in children living at high altitude. Three papers reported an association between hypoxaemia and death, with relative risks varying between 1.4 and 4.6. Papers describing predictors of hypoxaemia have focused on clinical signs for detecting hypoxaemia rather than on identifying risk factors for developing this complication. Hypoxaemia is a common and potentially lethal complication of ALRI in children under 5, particularly among those with severe disease and those living at high altitude. Given the observed high prevalence of hypoxaemia and its likely association with increased mortality, efforts should be made to improve the detection of hypoxaemia and to provide oxygen earlier to more children with severe ALRI.<|im_end|>\n",
      "<|im_start|>assistant\n",
      "\n"
     ]
    }
   ],
   "source": [
    "print(df.take(1)[0].prompt)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "data_path = \"spark-dl-datasets/pubmed_abstracts\"\n",
    "if on_databricks:\n",
    "    dbutils.fs.mkdirs(\"/FileStore/spark-dl-datasets\")\n",
    "    data_path = \"dbfs:/FileStore/\" + data_path\n",
    "\n",
    "df.write.mode(\"overwrite\").parquet(data_path)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Using Triton Inference Server\n",
    "In this section, we demonstrate integration with the [Triton Inference Server](https://developer.nvidia.com/nvidia-triton-inference-server), an open-source, GPU-accelerated serving solution for deep learning.  \n",
    "We use [PyTriton](https://github.com/triton-inference-server/pytriton), a Flask-like framework that handles client/server communication with the Triton server.  \n",
    "\n",
    "The process looks like this:\n",
    "- Distribute a PyTriton task across the Spark cluster, instructing each node to launch a Triton server process.\n",
    "- Define a Triton inference function, which contains a client that connects to the local server on a given node and sends inference requests.\n",
    "- Wrap the Triton inference function in a `predict_batch_udf` to launch parallel inference requests using Spark.\n",
    "- Finally, distribute a shutdown signal to terminate the Triton server processes on each node.\n",
    "\n",
    "<img src=\"../images/spark-server.png\" alt=\"drawing\" width=\"700\"/>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "from functools import partial"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Import the helper class from `server_utils.py`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "sc.addPyFile(\"server_utils.py\")\n",
    "\n",
    "from server_utils import TritonServerManager"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Define the Triton Server function:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "def triton_server(ports, model_path):\n",
    "    import time\n",
    "    import signal\n",
    "    import torch\n",
    "    import numpy as np\n",
    "    from pytriton.decorators import batch\n",
    "    from pytriton.model_config import DynamicBatcher, ModelConfig, Tensor\n",
    "    from pytriton.triton import Triton, TritonConfig\n",
    "    from pyspark import TaskContext\n",
    "    from transformers import AutoModelForCausalLM, AutoTokenizer\n",
    "\n",
    "    print(f\"SERVER: Initializing model on worker {TaskContext.get().partitionId()}.\")\n",
    "    model = AutoModelForCausalLM.from_pretrained(\n",
    "        model_path,\n",
    "        torch_dtype=torch.bfloat16,\n",
    "        device_map=\"auto\"\n",
    "    )\n",
    "    tokenizer = AutoTokenizer.from_pretrained(model_path, padding_side=\"left\")\n",
    "\n",
    "    @batch\n",
    "    def _infer_fn(**inputs):\n",
    "        # Flatten to 1-D explicitly; np.squeeze would collapse a batch of one to a scalar.\n",
    "        prompts = inputs[\"prompts\"].reshape(-1).tolist()\n",
    "        print(f\"SERVER: Received batch of size {len(prompts)}\")\n",
    "        decoded_prompts = [p.decode(\"utf-8\") for p in prompts]\n",
    "        tokenized_inputs = tokenizer(decoded_prompts, padding=True, return_tensors=\"pt\").to(model.device)\n",
    "        generated_ids = model.generate(**tokenized_inputs, max_new_tokens=256)\n",
    "        outputs = tokenizer.batch_decode(generated_ids[:, tokenized_inputs.input_ids.shape[1]:], skip_special_tokens=True)\n",
    "        return {\n",
    "            \"outputs\": np.array(outputs).reshape(-1, 1)\n",
    "        }\n",
    "\n",
    "    workspace_path = f\"/tmp/triton_{time.strftime('%m_%d_%M_%S')}\"\n",
    "    triton_conf = TritonConfig(http_port=ports[0], grpc_port=ports[1], metrics_port=ports[2])\n",
    "    with Triton(config=triton_conf, workspace=workspace_path) as triton:\n",
    "        triton.bind(\n",
    "            model_name=\"qwen-2.5\",\n",
    "            infer_func=_infer_fn,\n",
    "            inputs=[\n",
    "                Tensor(name=\"prompts\", dtype=object, shape=(-1,)),\n",
    "            ],\n",
    "            outputs=[\n",
    "                Tensor(name=\"outputs\", dtype=object, shape=(-1,)),\n",
    "            ],\n",
    "            config=ModelConfig(\n",
    "                max_batch_size=64,\n",
    "                batcher=DynamicBatcher(max_queue_delay_microseconds=5000),  # 5ms\n",
    "            ),\n",
    "            strict=True,\n",
    "        )\n",
    "\n",
    "        def _stop_triton(signum, frame):\n",
    "            print(\"SERVER: Received SIGTERM. Stopping Triton server.\")\n",
    "            triton.stop()\n",
    "\n",
    "        signal.signal(signal.SIGTERM, _stop_triton)\n",
    "\n",
    "        print(\"SERVER: Serving inference\")\n",
    "        triton.serve()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Start Triton servers"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The `TritonServerManager` handles the lifecycle of Triton server instances across the Spark cluster:\n",
    "- Find available ports for HTTP/gRPC/metrics\n",
    "- Deploy a server on each node via stage-level scheduling\n",
    "- Gracefully shut down servers across nodes"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "model_name = \"qwen-2.5\"\n",
    "server_manager = TritonServerManager(model_name=model_name, model_path=model_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2025-02-16 11:49:25,237 - INFO - Requesting stage-level resources: (cores=5, gpu=1.0)\n",
      "2025-02-16 11:49:25,239 - INFO - Starting 1 servers.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "                                                                                \r"
     ]
    },
    {
     "data": {
      "text/plain": [
       "{'cb4ae00-lcedt': (3490378, [7000, 7001, 7002])}"
      ]
     },
     "execution_count": 19,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Returns {'hostname': (server_pid, [http_port, grpc_port, metrics_port])}\n",
    "server_manager.start_servers(triton_server, wait_retries=24)  # allow up to 2 minutes for model loading"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Define client function"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Get the hostname -> url mapping from the server manager:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [],
   "source": [
    "host_to_grpc_url = server_manager.host_to_grpc_url  # or server_manager.host_to_http_url"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [],
   "source": [
    "def triton_fn(model_name, host_to_url):\n",
    "    import socket\n",
    "    import numpy as np\n",
    "    from pytriton.client import ModelClient\n",
    "\n",
    "    url = host_to_url[socket.gethostname()]\n",
    "    print(f\"Connecting to Triton model {model_name} at {url}.\")\n",
    "\n",
    "    def infer_batch(inputs):\n",
    "        with ModelClient(url, model_name, inference_timeout_s=500) as client:\n",
    "            # Flatten to 1-D explicitly; np.squeeze would turn a batch of one into a bare string.\n",
    "            flattened = np.reshape(inputs, -1).tolist()\n",
    "            # Encode batch\n",
    "            encoded_batch = [[text.encode(\"utf-8\")] for text in flattened]\n",
    "            encoded_batch_np = np.array(encoded_batch, dtype=np.bytes_)\n",
    "            # Run inference\n",
    "            result_data = client.infer_batch(encoded_batch_np)\n",
    "            result_data = np.squeeze(result_data[\"outputs\"], -1)\n",
    "            return result_data\n",
    "        \n",
    "    return infer_batch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [],
   "source": [
    "generate = predict_batch_udf(partial(triton_fn, model_name=model_name, host_to_url=host_to_grpc_url),\n",
    "                             return_type=StringType(),\n",
    "                             input_tensor_shapes=[[1]],\n",
    "                             batch_size=8)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Load DataFrame\n",
    "\n",
    "We'll parallelize over a small set of prompts for demonstration."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [],
   "source": [
    "df = spark.read.parquet(data_path).limit(64).repartition(8)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Run Inference"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[Stage 10:=====================>                                    (3 + 5) / 8]\r"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 10.5 ms, sys: 6.63 ms, total: 17.1 ms\n",
      "Wall time: 23.7 s\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "                                                                                \r"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "# first pass caches model/fn\n",
    "preds = df.withColumn(\"outputs\", generate(col(\"prompt\")))\n",
    "results = preds.collect()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[Stage 16:=====================>                                    (3 + 5) / 8]\r"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 8.1 ms, sys: 4.47 ms, total: 12.6 ms\n",
      "Wall time: 21.7 s\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "                                                                                \r"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "preds = df.withColumn(\"outputs\", generate(col(\"prompt\")))\n",
    "results = preds.collect()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Q: <|im_start|>system\n",
      "You are a knowledgeable AI assistant. Your job is to create a 2-3 sentence summary \n",
      "of a research abstract that captures the main objective, methodology, and key findings, using clear \n",
      "language while preserving technical accuracy and quantitative results.<|im_end|>\n",
      "<|im_start|>user\n",
      "Oral health promotion evaluation--time for development.\n",
      "Increasing emphasis is now being placed upon the evaluation of health service interventions to demonstrate their effects. A series of effectiveness reviews of the oral health education and promotion literature has demonstrated that many of these interventions are poorly and inadequately evaluated. It is therefore difficult to determine the effectiveness of many interventions. Based upon developments from the field of health promotion research this paper explores options for improving the quality of oral health promotion evaluation. It is essential that the methods and measures used in the evaluation of oral health promotion are appropriate to the intervention. For many oral health promotion interventions clinical measures and methods of evaluation may not be appropriate. This paper outlines an evaluation framework which can be used to assess the range of effects of oral health promotion programmes. Improving the quality of oral health promotion evaluation is a shared responsibility between researchers and those involved in the provision of programmes. The provision of adequate resources and training are essential requirements for this to be successfully achieved.<|im_end|>\n",
      "<|im_start|>assistant\n",
      " \n",
      "\n",
      "A: This research aims to improve the evaluation of oral health promotion programs by developing an appropriate framework. It explores how methods and measures should align with the specific nature of these interventions, emphasizing that both researchers and program providers must collaborate to ensure adequate resources and training are available for high-quality evaluations. \n",
      "\n"
     ]
    }
   ],
   "source": [
    "print(f\"Q: {results[0].prompt} \\n\")\n",
    "print(f\"A: {results[0].outputs} \\n\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Shut down server on each executor"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2025-02-16 11:51:42,365 - INFO - Requesting stage-level resources: (cores=5, gpu=1.0)\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2025-02-16 11:51:47,609 - INFO - Sucessfully stopped 1 servers.                 \n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "[True]"
      ]
     },
     "execution_count": 29,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "server_manager.stop_servers()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [],
   "source": [
    "if not on_databricks:  # On Databricks, spark.stop() puts the cluster in a bad state.\n",
    "    spark.stop()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "spark-dl-torch",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
